import configparser
import sys
import timeit
from pathlib import Path
import numpy as np 
import pandas as pd 
from fair_clustering_large_cluster import fair_clustering_large_cluster
from util.configutil import read_list
from util.utilhelpers import max_Viol_multi_color, x_for_colorBlind, max_RatioViol_multi_color, find_balance_multi_color, max_Viol_Normalized_multi_color
from add_viol_func import get_GF_max_additive_violation , get_DS_max_additive_violation


# The following variables were carried over from earlier code. They are passed
# through unchanged to fair_clustering_large_cluster in the loop below, and this
# script never varies them. Do NOT modify.
# NOTE(review): whether the callee still relies on them is not visible here — confirm.
LowerBound = 0 
ml_model_flag = True
p_acc = 1.0 


# Bounds of the cluster-count sweep, both inclusive:
#   k0   -> first number of clusters tried
#   kend -> last number of clusters tried
k0, kend = 7, 8



config_file = "config/example_large_cluster_config.ini"
config = configparser.ConfigParser(converters={'list': read_list})
# ConfigParser.read silently skips files it cannot open and returns the list
# of files it did read, so check it explicitly instead of failing later with
# an opaque KeyError on the section lookup.
if not config.read(config_file):
    raise FileNotFoundError(f"Could not read config file: {config_file}")

# Config section to use: first CLI argument if given, otherwise "adult_sex".
config_str = "adult_sex" if len(sys.argv) == 1 else sys.argv[1]
if config_str not in config:
    raise KeyError(f"Section {config_str!r} not found in {config_file}")
section = config[config_str]


# Read variables
data_dir = section.get("data_dir")
dataset = section.get("dataset")
clustering_config_file = section.get("config_file")
num_cluster = [int(k) for k in section.getlist("num_clusters")]
deltas = [float(d) for d in section.getlist("deltas")]
max_points = section.getint("max_points")

# NOTE(review): this unconditionally overrides the "max_points" value read from
# the config section above. Confirm the hard-coded cap is intended.
max_points = 20000




# Set up for the loop: every cluster count from k0 to kend (inclusive), and an
# empty results table that receives one row per cluster count.
clusters = list(range(k0, kend + 1))
df = pd.DataFrame(columns=[
    'num_clusters',
    'POF_GF', 'POF_DS', 'POF_doubly_GF', 'POF_doubly_DS',
    'additive_viol_colorBlind', 'additive_viol_GF', 'additive_viol_DS',
    'additive_viol_doubly_GF', 'additive_viol_doubly_DS',
    'ds_additive_viol_colorBlind', 'ds_additive_viol_GF', 'ds_additive_viol_DS',
    'ds_additive_viol_dGF', 'ds_additive_viol_dDS',
    'GF_emptycluster_flag', 'DS_emptycluster_flag',
    'ColorBlindTime', 'GF_Time', 'DS_Time', 'doubly_GF_Time', 'doubly_DS_Time',
    'doublyGF_ratio', 'doublyDS_ratio',
])

# Row label for the next result written into df.
iter_idx = 0





for cluster in clusters:
    # Run all clustering variants for the current number of clusters and
    # collect their costs, violations, flags and timings into one df row.
    output = fair_clustering_large_cluster(dataset, clustering_config_file, data_dir, cluster, deltas, max_points, LowerBound, p_acc, ml_model_flag)

    # num_points and scaling are read again after the loop to build the output
    # file name, so they are kept as plain names here.
    num_points = output['num_points']
    scaling = output["scaling"]

    # Price of fairness: each solution's cost relative to the color-blind cost.
    colorBlind_cost = output['unfair_cost']
    POF_GF = output['GF_cost'] / colorBlind_cost
    POF_DS = output['DS_cost'] / colorBlind_cost
    POF_doubly_GF = output["doublyGF_cost"] / colorBlind_cost
    POF_doubly_DS = output["doublyDS_cost"] / colorBlind_cost

    num_colors = output['num_colors']
    alpha = output['alpha']
    beta = output['beta']
    color_flag = output['color_flag']

    # Maximum additive group-fairness (GF) violation of each assignment.
    x_color_blind = x_for_colorBlind(output['unfair_assignments'], cluster)
    additive_viol_colorBlind = get_GF_max_additive_violation(x_color_blind, color_flag, cluster, num_colors, alpha, beta)
    additive_viol_GF = get_GF_max_additive_violation(output['GF_assignment'], color_flag, cluster, num_colors, alpha, beta)
    additive_viol_DS = get_GF_max_additive_violation(output['DS_assignment'], color_flag, cluster, num_colors, alpha, beta)
    # The doubly-constrained variants are evaluated over their reported number
    # of active clusters rather than `cluster`.
    additive_viol_doubly_GF = get_GF_max_additive_violation(output["doublyGF_assignment"], color_flag, output['dGF_num_clusters_active'], num_colors, alpha, beta)
    additive_viol_doubly_DS = get_GF_max_additive_violation(output["doublyDS_assignment"], color_flag, output['dDS_num_clusters_active'], num_colors, alpha, beta)

    # Maximum additive diversity (DS) violation of each solution's center colors,
    # measured against the DS lower/upper bounds.
    centerLowerBound = output["DS_lowerBounds"]
    centerUpperBound = output["DS_upperBounds"]
    # NOTE(review): num_cluster is the full list of cluster counts from the config,
    # not the current `cluster`. The original author flagged this as well; confirm
    # against get_DS_max_additive_violation's signature before changing it.
    ds_additive_viol_colorBlind = get_DS_max_additive_violation(output["ColorBlind_center_colors"], num_cluster, num_colors, centerUpperBound, centerLowerBound)
    ds_additive_viol_GF = get_DS_max_additive_violation(output["GF_center_colors"], num_cluster, num_colors, centerUpperBound, centerLowerBound)
    ds_additive_viol_DS = get_DS_max_additive_violation(output["DS_center_colors"], num_cluster, num_colors, centerUpperBound, centerLowerBound)
    ds_additive_viol_dGF = get_DS_max_additive_violation(output["doublyGF_center_colors"], num_cluster, num_colors, centerUpperBound, centerLowerBound)
    ds_additive_viol_dDS = get_DS_max_additive_violation(output["doublyDS_center_colors"], num_cluster, num_colors, centerUpperBound, centerLowerBound)

    # Recorded run times; the ratios compare each doubly-constrained variant
    # against its single-constraint counterpart.
    cluster_time = output["ColorBlind_time"]
    GF_time = output["GF_time"]
    DS_time = output["DS_time"]
    doubly_GF_time = output["doubly_GF_time"]
    doubly_DS_time = output["doubly_DS_time"]

    # Key the row by column name so a value can never land under the wrong
    # header; a missing column raises KeyError instead of misaligning silently.
    row = {
        'num_clusters': cluster,
        'POF_GF': POF_GF,
        'POF_DS': POF_DS,
        'POF_doubly_GF': POF_doubly_GF,
        'POF_doubly_DS': POF_doubly_DS,
        'additive_viol_colorBlind': additive_viol_colorBlind,
        'additive_viol_GF': additive_viol_GF,
        'additive_viol_DS': additive_viol_DS,
        'additive_viol_doubly_GF': additive_viol_doubly_GF,
        'additive_viol_doubly_DS': additive_viol_doubly_DS,
        'ds_additive_viol_colorBlind': ds_additive_viol_colorBlind,
        'ds_additive_viol_GF': ds_additive_viol_GF,
        'ds_additive_viol_DS': ds_additive_viol_DS,
        'ds_additive_viol_dGF': ds_additive_viol_dGF,
        'ds_additive_viol_dDS': ds_additive_viol_dDS,
        'GF_emptycluster_flag': output['GF_emptycluster_flag'],
        'DS_emptycluster_flag': output['DS_emptycluster_flag'],
        'ColorBlindTime': cluster_time,
        'GF_Time': GF_time,
        'DS_Time': DS_time,
        'doubly_GF_Time': doubly_GF_time,
        'doubly_DS_Time': doubly_DS_time,
        'doublyGF_ratio': doubly_GF_time / GF_time,
        'doublyDS_ratio': doubly_DS_time / DS_time,
    }
    df.loc[iter_idx] = [row[col] for col in df.columns]

    iter_idx += 1


# Build the output file name. num_points and scaling come from the last
# iteration of the sweep loop.
scale_flag = 'normalized' if scaling else 'unnormalized'
filename = f"{dataset}_kcenter_{int(num_points)}_{scale_flag}.csv"

results_dir = Path('Results')
# DataFrame.to_csv does not create missing parent directories.
results_dir.mkdir(parents=True, exist_ok=True)

# Do not overwrite an existing result: keep prefixing 'new' until the name is free.
filepath = results_dir / filename
while filepath.is_file():
    filename = 'new' + filename
    filepath = results_dir / filename

df.to_csv(filepath, sep=',', index=False)